/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import Unsafe from './Unsafe';
import FseCompressionTable from './FseCompressionTable';
import Constants from './Constants';
import { BitInputStream, Initializer, Loader } from './BitInputStream'
import NumberTransform from './NumberTransform'
import Util from './Util'
import BitOutputStream from './BitOutputStream'
import Exception from '../util/Exception'
import Long from '../util/long/index'

/**
 * Finite State Entropy (FSE) coding primitives.
 *
 * Provides decoding with a prebuilt decoding {@link Table}, encoding with an
 * {@link FseCompressionTable}, and the helpers used to choose a table log,
 * normalize symbol counts and serialize those normalized counts.
 *
 * Addresses are {@link Long} offsets passed to the project's `Unsafe` accessors
 * together with a base object (a byte array at the visible call sites).
 */
export class FiniteStateEntropy {
    /** Maximum symbol value (the full byte alphabet). */
    public static MAX_SYMBOL: number = 255;
    /** Largest supported table log; a table holds `1 << tableLog` entries. */
    public static MAX_TABLE_LOG: number = 12;
    /** Smallest supported table log; also the bias subtracted when writing the table-log header field. */
    public static MIN_TABLE_LOG: number = 5;
    // Rounding thresholds used by normalizeCounts for scaled probabilities below 8, indexed by that probability.
    private static REST_TO_BEAT: Int32Array = new Int32Array([0, 473195, 504333, 520860, 550000, 700000, 750000, 830000]);
    // Sentinel used by normalizeCounts2 for symbols whose normalized count has not been assigned yet.
    private static UNASSIGNED: number = -2;

    constructor() {
    }

    /**
     * Decodes an FSE bit stream into `outputBuffer`.
     *
     * Two decoder states are read from the head of the stream (`table.log2Size` bits each)
     * and then used alternately, so consecutive output bytes come from alternating states.
     * For each decoded symbol, the next state is `newState[state]` plus the next
     * `numberOfBits[state]` bits of the stream.
     *
     * While at least four bytes of output room remain, four symbols are decoded per iteration.
     * The rest is decoded one symbol at a time until the bit reader reports overflow. At that
     * point the symbol of the other state is emitted as the final byte.
     *
     * @param table decoding table (symbol, bit count and next-state base for every state)
     * @param inputBase base object of the compressed input
     * @param inputAddress address of the first compressed byte
     * @param inputLimit address one past the last compressed byte
     * @param outputBuffer destination; decoding starts at its first byte
     * @returns the number of bytes written to `outputBuffer`
     * @throws whatever `Util.verify` raises ("Output buffer is too small") when fewer than two
     *         bytes of room remain while decoding the tail
     */
    public static decompress(table: Table, inputBase: any, inputAddress: Long, inputLimit: Long, outputBuffer: Int8Array): number {
        const outputBase: any = outputBuffer;
        const outputAddress: Long = Long.fromNumber(Unsafe.ARRAY_BYTE_BASE_OFFSET);
        const outputLimit: Long = outputAddress.add(outputBuffer.length);

        const input: Long = inputAddress;
        let output: Long = outputAddress;

        // initialize bit stream
        const initializer: Initializer = new Initializer(inputBase, input, inputLimit);
        initializer.initialize();
        let bitsConsumed: number = initializer.getBitsConsumed();
        let currentAddress: Long = initializer.getCurrentAddress();
        let bits: Long = initializer.getBits();

        // initialize first FSE stream
        let state1: number = BitInputStream.peekBits(bitsConsumed, bits, table.log2Size).toInt();
        bitsConsumed += table.log2Size;

        let loader: Loader = new Loader(inputBase, input, currentAddress, bits, bitsConsumed);
        loader.load();
        bits = loader.getBits();
        bitsConsumed = loader.getBitsConsumed();
        currentAddress = loader.getCurrentAddress();

        // initialize second FSE stream
        let state2: number = BitInputStream.peekBits(bitsConsumed, bits, table.log2Size).toInt();
        bitsConsumed += table.log2Size;

        loader = new Loader(inputBase, input, currentAddress, bits, bitsConsumed);
        loader.load();
        bits = loader.getBits();
        bitsConsumed = loader.getBitsConsumed();
        currentAddress = loader.getCurrentAddress();

        const symbols: Int8Array = table.symbols;
        const numbersOfBits: Int8Array = table.numberOfBits;
        const newStates: Int32Array = table.newState;

        // Fast path: decode 4 symbols per iteration while at least 4 bytes of output room remain.
        while (output.lessThanOrEqual(outputLimit.sub(4))) {
            let numberOfBits: number;

            Unsafe.putByte(outputBase, output.toNumber(), symbols[state1]);
            numberOfBits = numbersOfBits[state1];
            state1 = newStates[state1] + BitInputStream.peekBits(bitsConsumed, bits, numberOfBits).toInt();
            bitsConsumed += numberOfBits;

            Unsafe.putByte(outputBase, output.add(1).toNumber(), symbols[state2]);
            numberOfBits = numbersOfBits[state2];
            state2 = newStates[state2] + BitInputStream.peekBits(bitsConsumed, bits, numberOfBits).toInt();
            bitsConsumed += numberOfBits;

            Unsafe.putByte(outputBase, output.add(2).toNumber(), symbols[state1]);
            numberOfBits = numbersOfBits[state1];
            state1 = newStates[state1] + BitInputStream.peekBits(bitsConsumed, bits, numberOfBits).toInt();
            bitsConsumed += numberOfBits;

            Unsafe.putByte(outputBase, output.add(3).toNumber(), symbols[state2]);
            numberOfBits = numbersOfBits[state2];
            state2 = newStates[state2] + BitInputStream.peekBits(bitsConsumed, bits, numberOfBits).toInt();
            bitsConsumed += numberOfBits;

            output = output.add(Constants.SIZE_OF_INT);

            loader = new Loader(inputBase, input, currentAddress, bits, bitsConsumed);
            // NOTE(review): a `true` result presumably means the reader can no longer refill
            // fully; the tail loop below then finishes decoding symbol by symbol.
            const done: boolean = loader.load();
            bitsConsumed = loader.getBitsConsumed();
            bits = loader.getBits();
            currentAddress = loader.getCurrentAddress();
            if (done) {
                break;
            }
        }

        // Tail: decode one symbol at a time, alternating states. When the reader reports overflow
        // after a transition, the other state's symbol is the last byte of the output.
        while (true) {
            Util.verify(output.lessThanOrEqual(outputLimit.sub(2)), input, "Output buffer is too small");
            Unsafe.putByte(outputBase, output.toNumber(), symbols[state1]);
            output = output.add(1);
            const numberOfBits: number = numbersOfBits[state1];
            state1 = newStates[state1] + BitInputStream.peekBits(bitsConsumed, bits, numberOfBits).toInt();
            bitsConsumed += numberOfBits;

            loader = new Loader(inputBase, input, currentAddress, bits, bitsConsumed);
            loader.load();
            bitsConsumed = loader.getBitsConsumed();
            bits = loader.getBits();
            currentAddress = loader.getCurrentAddress();

            if (loader.isOverflow()) {
                Unsafe.putByte(outputBase, output.toNumber(), symbols[state2]);
                output = output.add(1);
                break;
            }

            Util.verify(output.lessThanOrEqual(outputLimit.sub(2)), input, "Output buffer is too small");
            Unsafe.putByte(outputBase, output.toNumber(), symbols[state2]);
            output = output.add(1);
            const numberOfBits1: number = numbersOfBits[state2];
            state2 = newStates[state2] + BitInputStream.peekBits(bitsConsumed, bits, numberOfBits1).toInt();
            bitsConsumed += numberOfBits1;

            loader = new Loader(inputBase, input, currentAddress, bits, bitsConsumed);
            loader.load();
            bitsConsumed = loader.getBitsConsumed();
            bits = loader.getBits();
            currentAddress = loader.getCurrentAddress();

            if (loader.isOverflow()) {
                Unsafe.putByte(outputBase, output.toNumber(), symbols[state1]);
                output = output.add(1);
                break;
            }
        }

        return output.sub(outputAddress).toInt();
    }

    /**
     * Convenience wrapper around {@link FiniteStateEntropy.compressUnsafe} for an input held in
     * a byte array. The input address is `Unsafe.ARRAY_BYTE_BASE_OFFSET`, i.e. the first byte of `input`.
     *
     * @returns the result of `compressUnsafe` (0 when `inputSize <= 2`)
     */
    public static compress(outputBase: any, outputAddress: Long, outputSize: number, input: Int8Array, inputSize: number, table: FseCompressionTable): number {
        return FiniteStateEntropy.compressUnsafe(outputBase, outputAddress, outputSize, input, Long.fromNumber(Unsafe.ARRAY_BYTE_BASE_OFFSET), inputSize, table);
    }

    /**
     * FSE-encodes `inputSize` bytes starting at `inputAddress` into the output region.
     *
     * The input is consumed from its last byte back to its first, using two interleaved
     * encoder states. Both states are finalized with `table.finish` (state2 first, then state1)
     * before the bit stream is closed.
     *
     * @param outputBase base object of the destination
     * @param outputAddress address where compressed bytes are written
     * @param outputSize capacity of the destination in bytes
     * @param inputBase base object of the source
     * @param inputAddress address of the first input byte
     * @param inputSize number of input bytes
     * @param table compression table used for `begin`/`encode`/`finish`
     * @returns 0 when `inputSize <= 2` (nothing is written); otherwise the value returned by
     *          `BitOutputStream.close()` (presumably the compressed size in bytes)
     * @throws via `Util.checkArgument` ("Output buffer too small") when
     *         `outputSize < Constants.SIZE_OF_LONG`
     */
    public static compressUnsafe(outputBase: any, outputAddress: Long, outputSize: number, inputBase: any, inputAddress: Long, inputSize: number, table: FseCompressionTable): number {
        Util.checkArgument(outputSize >= Constants.SIZE_OF_LONG, "Output buffer too small");

        const start: Long = inputAddress;
        const inputLimit: Long = start.add(inputSize);

        // Encoding walks backwards from the end of the input.
        let input: Long = inputLimit;

        if (inputSize <= 2) {
            return 0;
        }

        const stream: BitOutputStream = new BitOutputStream(outputBase, outputAddress, outputSize);

        let state1: number;
        let state2: number;

        if ((inputSize & 1) !== 0) {
            // Odd size: seed both states and encode one extra symbol so the remainder is even.
            input = input.sub(1);
            state1 = table.begin(Unsafe.getByte(inputBase, input.toNumber()));

            input = input.sub(1);
            state2 = table.begin(Unsafe.getByte(inputBase, input.toNumber()));

            input = input.sub(1);
            state1 = table.encode(stream, state1, Unsafe.getByte(inputBase, input.toNumber()));

            stream.flush();
        } else {
            input = input.sub(1);
            state2 = table.begin(Unsafe.getByte(inputBase, input.toNumber()));

            input = input.sub(1);
            state1 = table.begin(Unsafe.getByte(inputBase, input.toNumber()));
        }

        // join to mod 4
        inputSize -= 2;

        // The constant sub-expression compares the accumulator width (SIZE_OF_LONG * 8 bits) against
        // four encodings of up to MAX_TABLE_LOG bits plus 7 spare bits.
        if ((Constants.SIZE_OF_LONG * 8 > FiniteStateEntropy.MAX_TABLE_LOG * 4 + 7) && (inputSize & 2) !== 0) { /* test bit 2 */
            input = input.sub(1);
            state2 = table.encode(stream, state2, Unsafe.getByte(inputBase, input.toNumber()));

            input = input.sub(1);
            state1 = table.encode(stream, state1, Unsafe.getByte(inputBase, input.toNumber()));

            stream.flush();
        }

        // 2 or 4 encoding per loop
        while (input.greaterThan(start)) {
            input = input.sub(1);
            state2 = table.encode(stream, state2, Unsafe.getByte(inputBase, input.toNumber()));

            if (Constants.SIZE_OF_LONG * 8 < FiniteStateEntropy.MAX_TABLE_LOG * 2 + 7) {
                stream.flush();
            }

            input = input.sub(1);
            state1 = table.encode(stream, state1, Unsafe.getByte(inputBase, input.toNumber()));

            if (Constants.SIZE_OF_LONG * 8 > FiniteStateEntropy.MAX_TABLE_LOG * 4 + 7) {
                input = input.sub(1);
                state2 = table.encode(stream, state2, Unsafe.getByte(inputBase, input.toNumber()));

                input = input.sub(1);
                state1 = table.encode(stream, state1, Unsafe.getByte(inputBase, input.toNumber()));
            }

            stream.flush();
        }

        table.finish(stream, state2);
        table.finish(stream, state1);

        return stream.close();
    }

    /**
     * Picks a table log for the given input.
     *
     * Starts from `maxTableLog` and lowers it for small inputs (`highestBit(inputSize - 1) - 2`).
     * It then raises the value to at least `Util.minTableLog(inputSize, maxSymbol)`. Finally the
     * result is clamped to `[MIN_TABLE_LOG, MAX_TABLE_LOG]`.
     *
     * @throws Exception when `inputSize <= 1` (such input should be RLE-encoded instead)
     */
    public static optimalTableLog(maxTableLog: number, inputSize: number, maxSymbol: number): number
    {
        if (inputSize <= 1) {
            throw new Exception(); // not supported. Use RLE instead
        }

        let result: number = maxTableLog;

        result = Math.min(result, Util.highestBit((inputSize - 1)) - 2); // we may be able to reduce accuracy if input is small

        // Need a minimum to safely represent all symbol values
        result = Math.max(result, Util.minTableLog(inputSize, maxSymbol));

        result = Math.max(result, FiniteStateEntropy.MIN_TABLE_LOG);
        result = Math.min(result, FiniteStateEntropy.MAX_TABLE_LOG);

        return result;
    }

    /**
     * Scales `counts` (which sum to `total`) into `normalizedCounts`, whose magnitudes sum to
     * `1 << tableLog`.
     *
     * - Symbols with a zero count get 0.
     * - Symbols with a positive count of at most `total >>> tableLog` get -1; each such entry
     *   still takes one slot.
     * - Other symbols get a proportional count computed in 62-bit fixed point. Counts below 8 are
     *   rounded up when the remainder exceeds the matching `REST_TO_BEAT` threshold.
     *
     * The leftover rounding error is added to the most probable symbol. If the leftover is negative
     * and its magnitude is at least half of that symbol's count, the table is recomputed with
     * normalizeCounts2 instead.
     *
     * @returns `tableLog`
     * @throws via `Util.checkArgument` when `tableLog` is outside `[MIN_TABLE_LOG, MAX_TABLE_LOG]`
     *         or below `Util.minTableLog(total, maxSymbol)`
     * @throws Exception when a single symbol accounts for the whole input (RLE case)
     */
    public static normalizeCounts(normalizedCounts: Int16Array, tableLog: number, counts: Int32Array, total: number, maxSymbol: number): number
    {
        Util.checkArgument(tableLog >= FiniteStateEntropy.MIN_TABLE_LOG, "Unsupported FSE table size");
        Util.checkArgument(tableLog <= FiniteStateEntropy.MAX_TABLE_LOG, "FSE table size too large");
        Util.checkArgument(tableLog >= Util.minTableLog(total, maxSymbol), "FSE table size too small");

        const scale: Long = Long.fromNumber(62 - tableLog);
        const step: Long = Long.fromNumber(1).shiftLeft(62).divide(total); // (1 << 62) / total
        const vstep: Long = Long.fromNumber(1).shiftLeft(scale.sub(20));

        let stillToDistribute: number = 1 << tableLog;

        let largest: number = 0;
        let largestProbability: number = 0;
        const lowThreshold: number = total >>> tableLog;

        for (let symbolnum: number = 0; symbolnum <= maxSymbol; symbolnum++) {
            if (counts[symbolnum] === total) {
                throw new Exception(); // TODO: should have been RLE-compressed by upper layers
            }
            if (counts[symbolnum] === 0) {
                normalizedCounts[symbolnum] = 0;
                continue;
            }
            if (counts[symbolnum] <= lowThreshold) {
                normalizedCounts[symbolnum] = -1;
                stillToDistribute--;
            } else {
                let probability: number = NumberTransform.toShort(
                (step.multiply(counts[symbolnum])).shiftRightUnsigned(scale)
                    .toInt());
                if (probability < 8) {
                    // NOTE(port): this 64-bit product must match the reference implementation's long arithmetic.
                    const restToBeat: Long = vstep.multiply(FiniteStateEntropy.REST_TO_BEAT[probability]);
                    const delta: Long = step.multiply(counts[symbolnum]).sub(Long.fromNumber(probability).shiftLeft(scale));
                    if (delta.greaterThan(restToBeat)) {
                        probability++;
                    }
                }
                if (probability > largestProbability) {
                    largestProbability = probability;
                    largest = symbolnum;
                }
                normalizedCounts[symbolnum] = probability;
                stillToDistribute -= probability;
            }
        }

        if (-stillToDistribute >= (normalizedCounts[largest] >>> 1)) {
            // corner case: the largest symbol cannot absorb the correction, use the alternate method
            FiniteStateEntropy.normalizeCounts2(normalizedCounts, tableLog, counts, total, maxSymbol);
        } else {
            normalizedCounts[largest] += NumberTransform.toShort(stillToDistribute);
        }

        return tableLog;
    }

    /**
     * Alternate normalization used when normalizeCounts cannot absorb its rounding error.
     *
     * 1. Symbols at or below `total >>> tableLog` get -1.
     * 2. Symbols at or below 1.5x that threshold get 1.
     * 3. All other symbols share the remaining slots proportionally, using 62-bit fixed point.
     *
     * @returns always 0
     * @throws Error if a proportionally assigned symbol would receive a weight below 1
     */
    private static normalizeCounts2(normalizedCounts: Int16Array, tableLog: number, counts: Int32Array, total: number, maxSymbol: number): number {
        let distributed: number = 0;

        const lowThreshold: number = total >>> tableLog; // minimum count below which frequency in the normalized table is "too small" (~ < 1)
        let lowOne: number = (total * 3) >>> (tableLog + 1); // 1.5 * lowThreshold. If count in (lowThreshold, lowOne] => assign frequency 1

        for (let i: number = 0; i <= maxSymbol; i++) {
            if (counts[i] === 0) {
                normalizedCounts[i] = 0;
            } else if (counts[i] <= lowThreshold) {
                normalizedCounts[i] = -1;
                distributed++;
                total -= counts[i];
            } else if (counts[i] <= lowOne) {
                normalizedCounts[i] = 1;
                distributed++;
                total -= counts[i];
            } else {
                normalizedCounts[i] = FiniteStateEntropy.UNASSIGNED;
            }
        }

        const normalizationFactor: number = 1 << tableLog;
        let toDistribute: number = normalizationFactor - distributed;

        // `/` is floating-point in JavaScript; truncate to get the integer division this heuristic
        // is defined with (e.g. total=10, toDistribute=3, lowOne=3 must NOT enter this branch).
        if (Math.trunc(total / toDistribute) > lowOne) {
            // risk of rounding to zero: also give frequency 1 to symbols at or below 1.5x the average share
            lowOne = Math.trunc((total * 3) / (toDistribute * 2));
            for (let i: number = 0; i <= maxSymbol; i++) {
                if ((normalizedCounts[i] === FiniteStateEntropy.UNASSIGNED) && (counts[i] <= lowOne)) {
                    normalizedCounts[i] = 1;
                    distributed++;
                    total -= counts[i];
                }
            }
            toDistribute = normalizationFactor - distributed;
        }

        if (distributed === maxSymbol + 1) {
            // every symbol received -1 or 1: give all remaining slots to the most frequent symbol
            let maxValue: number = 0;
            let maxCount: number = 0;
            for (let i: number = 0; i <= maxSymbol; i++) {
                if (counts[i] > maxCount) {
                    maxValue = i;
                    maxCount = counts[i];
                }
            }
            normalizedCounts[maxValue] += NumberTransform.toShort(toDistribute);
            return 0;
        }

        if (total === 0) {
            // all remaining mass was consumed above: hand out leftover slots round-robin to positive entries
            for (let i: number = 0; toDistribute > 0; i = (i + 1) % (maxSymbol + 1)) {
                if (normalizedCounts[i] > 0) {
                    toDistribute--;
                    normalizedCounts[i]++;
                }
            }
            return 0;
        }

        // TODO: simplify/document this code
        // Proportional split of the remaining slots in fixed point with vStepLog fractional bits;
        // each symbol's weight is the difference of the rounded cumulative positions.
        const vStepLog: Long = Long.fromNumber(62 - tableLog);
        const mid: Long = Long.fromNumber(1).shiftLeft((vStepLog.sub(1))).sub(1);
        const rStep: Long = ((Long.fromNumber(1)
            .shiftLeft(vStepLog)
            .multiply(toDistribute)).add(mid)).divide(total); /* scale on remaining */
        let tmpTotal: Long = mid;
        for (let i: number = 0; i <= maxSymbol; i++) {
            if (normalizedCounts[i] === FiniteStateEntropy.UNASSIGNED) {
                const end: Long = tmpTotal.add((rStep.multiply(counts[i])));
                const sStart: number = tmpTotal.shiftRightUnsigned(vStepLog).toInt();
                const sEnd: number = end.shiftRightUnsigned(vStepLog).toInt();
                const weight: number = sEnd - sStart;

                if (weight < 1) {
                    throw new Error();
                }
                normalizedCounts[i] = NumberTransform.toShort(weight);
                tmpTotal = end;
            }
        }

        return 0;
    }

    /**
     * Serializes normalized counts as an FSE table description (zstd format, RFC 8478 §4.1.1).
     *
     * The description starts with a 4-bit field `tableLog - MIN_TABLE_LOG`. Each count follows as a
     * variable-width value (+1 for extra accuracy). A zero count is followed by 2-bit repeat flags
     * giving how many more zero counts follow: a value of 3 means "3 more, and another flag follows".
     * Bits are flushed 16 at a time as little shorts through `Unsafe.putShort`.
     *
     * @param outputBase base object of the destination
     * @param outputAddress address where the description is written
     * @param outputSize capacity of the destination in bytes
     * @param normalizedCounts counts produced by normalizeCounts (-1 marks a "less than one" symbol)
     * @param maxSymbol highest symbol value present
     * @param tableLog table log the counts were normalized to
     * @returns number of bytes of the description
     * @throws via `Util.checkArgument` on an out-of-range `tableLog` or an output buffer too small
     * @throws Error if the counts do not sum to `1 << tableLog`
     */
    public static writeNormalizedCounts(outputBase: any, outputAddress: Long, outputSize: number, normalizedCounts: Int16Array, maxSymbol: number, tableLog: number): number
    {
        Util.checkArgument(tableLog <= FiniteStateEntropy.MAX_TABLE_LOG, "FSE table too large");
        Util.checkArgument(tableLog >= FiniteStateEntropy.MIN_TABLE_LOG, "FSE table too small");

        let output: Long = outputAddress;
        const outputLimit: Long = outputAddress.add(outputSize);

        const tableSize: number = 1 << tableLog;

        let bitCount: number = 0;

        // encode table size
        let bitStream: number = (tableLog - FiniteStateEntropy.MIN_TABLE_LOG);
        bitCount += 4;

        let remaining: number = tableSize + 1; // +1 for extra accuracy
        let threshold: number = tableSize;
        let tableBitCount: number = tableLog + 1;

        let symbolnum: number = 0;

        let previousIs0: boolean = false;
        while (remaining > 1) {
            if (previousIs0) {
                let start: number = symbolnum;

                // find the run of symbols with count 0
                while (normalizedCounts[symbolnum] === 0) {
                    symbolnum++;
                }

                // 8 repeat flags of value 3 (24 symbols) at once, flushed immediately
                while (symbolnum >= start + 24) {
                    start += 24;
                    bitStream |= (65535 << bitCount);
                    Util.checkArgument(
                    output.add(Constants.SIZE_OF_SHORT)
                        .lessThanOrEqual(outputLimit), "Output buffer too small");

                    Unsafe.putShort(outputBase, output.toNumber(), NumberTransform.toShort(bitStream));
                    output = output.add(Constants.SIZE_OF_SHORT);

                    // flushed right away, so bitCount does not need to grow by 16
                    bitStream >>>= 16;
                }
                // remaining full groups of 3
                while (symbolnum >= start + 3) {
                    start += 3;
                    bitStream |= 3 << bitCount;
                    bitCount += 2;
                }

                // tail of the run (0..2)
                bitStream |= (symbolnum - start) << bitCount;
                bitCount += 2;

                if (bitCount > 16) {
                    Util.checkArgument(
                    output.add(Constants.SIZE_OF_SHORT)
                        .lessThanOrEqual(outputLimit), "Output buffer too small");

                    Unsafe.putShort(outputBase, output.toNumber(), NumberTransform.toShort(bitStream));
                    output = output.add(Constants.SIZE_OF_SHORT);

                    bitStream >>>= 16;
                    bitCount -= 16;
                }
            }

            let count: number = normalizedCounts[symbolnum++];
            const max: number = (2 * threshold - 1) - remaining;
            remaining -= count < 0 ? -count : count;
            count++; /* +1 for extra accuracy */
            if (count >= threshold) {
                count += max;
            }
            bitStream |= count << bitCount;
            bitCount += tableBitCount;
            // small values are written with one bit fewer
            bitCount -= (count < max ? 1 : 0);
            previousIs0 = (count === 1);

            if (remaining < 1) {
                throw new Error();
            }

            while (remaining < threshold) {
                tableBitCount--;
                threshold >>= 1;
            }

            if (bitCount > 16) {
                Util.checkArgument(output.add(Constants.SIZE_OF_SHORT).lessThanOrEqual(outputLimit), "Output buffer too small");

                Unsafe.putShort(outputBase, output.toNumber(), NumberTransform.toShort(bitStream));
                output = output.add(Constants.SIZE_OF_SHORT);

                bitStream >>>= 16;
                bitCount -= 16;
            }
        }

        // flush remaining bits; only the bytes they actually occupy count toward the size
        Util.checkArgument(output.add(Constants.SIZE_OF_SHORT).lessThanOrEqual(outputLimit), "Output buffer too small");
        Unsafe.putShort(outputBase, output.toNumber(), NumberTransform.toShort(bitStream));
        // integer division: `/` alone would hand a fractional byte count to Long.add
        output = output.add(Math.trunc((bitCount + 7) / 8));

        Util.checkArgument(symbolnum <= maxSymbol + 1, "Error"); // TODO

        return output.sub(outputAddress).toInt();
    }
}

/**
 * FSE decoding table with `1 << log2Size` states.
 *
 * For each state it stores the symbol emitted in that state, the number of bits to read next,
 * and the base value added to those bits to form the next state.
 */
export class Table {
    /** log2 of the number of states. Defaults to 0 so it is never `undefined` on the getTables path. */
    log2Size: number = 0;
    /** Base value of the next state, per state. */
    newState: Int32Array;
    /** Symbol emitted in each state. */
    symbols: Int8Array;
    /** Number of bits to read after emitting each state's symbol. */
    numberOfBits: Int8Array;

    /**
     * Allocates zero-filled arrays for `1 << log2Capacity` states.
     *
     * `log2Size` is left at 0; presumably whoever fills the arrays sets it before decoding
     * (TODO confirm against the table reader).
     */
    public static getTables(log2Capacity: number): Table {
        const table = new Table();
        const capacity: number = 1 << log2Capacity;
        table.newState = new Int32Array(capacity);
        table.symbols = new Int8Array(capacity);
        table.numberOfBits = new Int8Array(capacity);
        return table;
    }

    /**
     * Wraps existing arrays (without copying) as a table of `1 << log2Size` states.
     *
     * @throws Exception if any array's length differs from `1 << log2Size`
     */
    public static getTable(log2Size: number, newState: Int32Array, symbols: Int8Array, numberOfBits: Int8Array): Table {
        const size: number = 1 << log2Size;
        if (newState.length !== size || symbols.length !== size || numberOfBits.length !== size) {
            throw new Exception("Expected arrays to match provided size");
        }

        const table = new Table();
        table.log2Size = log2Size;
        table.newState = newState;
        table.symbols = symbols;
        table.numberOfBits = numberOfBits;
        return table;
    }
}
